import torch
from models.stgnn import STGNN

def predict(model, X, A):
    """Run ``model`` in inference mode on ``X`` and return arg-max indices.

    Args:
        model: A ``torch.nn.Module`` invoked as ``model(x)`` with one 4-D
            float tensor.
        X: Rank-3 array-like or tensor. Axes 1 and 2 are swapped and a
            trailing singleton axis is appended before the forward pass.
            NOTE(review): the layout comment below suggests ``X`` arrives as
            ``[B, T, F]`` so the model sees ``[B, F, T, 1]`` -- confirm
            against the data loader.
        A: Currently unused. NOTE(review): presumably the graph adjacency
            matrix; it is never forwarded to ``model`` -- confirm whether the
            model already holds it from construction.

    Returns:
        Integer tensor of indices from ``argmax`` over the last axis of the
        model output, on the same device as the model's parameters.
    """
    # Build the input on the model's device. The legacy torch.FloatTensor
    # constructor always produced a CPU tensor, which breaks GPU models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x = torch.as_tensor(X, dtype=torch.float32, device=device)
            x = x.permute(0, 2, 1).unsqueeze(-1)  # [B,F,T,1]
            output = model(x)
            pred = torch.argmax(output, dim=-1)
    finally:
        # Restore the caller's mode so predicting mid-training does not
        # silently leave dropout / batch-norm in eval behaviour.
        model.train(was_training)
    return pred